from tqdm import tqdm
import torchvision
import torch
import lpips
import PIL

from utils import load_resize_image


def psnr_loss(img0, img1):
    """Return the PSNR (in dB) between two images as a 0-dim tensor.

    Despite the name this is a similarity score, not a loss: higher means
    the images are closer. Both inputs are mapped from [-1, 1] to the
    [0, 255] pixel range before comparison (assumes they were normalized to
    [-1, 1], e.g. by ``Normalize(0.5, 0.5)``). The mean squared error is
    averaged over every element, so for a batch this is the PSNR of the
    pooled MSE, not the mean of per-image PSNRs. Identical inputs give an
    MSE of 0 and therefore ``inf``.
    """
    def to_pixel_range(t):
        # [-1, 1] -> [0, 255]
        return (t + 1) * 127.5

    mse_value = torch.nn.functional.mse_loss(to_pixel_range(img0), to_pixel_range(img1))
    peak_squared = 255 ** 2
    return torch.log10(peak_squared / mse_value) * 10


if __name__ == "__main__":
    with open("data/captions.txt", "r") as f:
        data = f.read().split("\n")[:-1]
        image_ids = [d.split(".jpg,")[0] for d in data]

    to_tensor = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(0.5, 0.5),
    ])

    lpips_loss = lpips.LPIPS(net='alex').to("cuda")

    results = {}
    for name in ["rec", "rec_org", "rec_nlt"]:
        dirs = ["020", "050", "100", "200"]
        dirs = ["050"]

        for image_id in tqdm(image_ids):
            results[image_id] = []

            img_org = to_tensor(load_resize_image(f"data/images/{image_id}.jpg", size=512)).to("cuda").unsqueeze(0)

            for d in dirs:
                img = to_tensor(load_resize_image(f"results/{name}{d}/{image_id}.png", size=512)).to("cuda").unsqueeze(0)

                lpips_score = lpips_loss(img_org, img).mean().item()
                psnr_score = psnr_loss(img_org, img).item()
                results[image_id].append([lpips_score, psnr_score])


        with open(f"results/{name}_lpips.csv", "w") as f:
            f.write("image_id," + ",".join(dirs) + "\n")

            for image_id, l in results.items():
                f.write(image_id + ",")
                f.write(",".join([str(l[i][0]) for i in range(len(dirs))]))
                f.write("\n")

        with open(f"results/{name}_psnr.csv", "w") as f:
            f.write("image_id," + ",".join(dirs) + "\n")

            for image_id, l in results.items():
                f.write(image_id + ",")
                f.write(",".join([str(l[i][1]) for i in range(len(dirs))]))
                f.write("\n")
